from torch import nn
import torch
from config import Config
import time


class BaseNet(nn.Module):
    """
    提供load和save接口
    """

    def __init__(self):
        super(BaseNet, self).__init__()

    def load(self, file):
        self.load_state_dict(torch.load(file))

    def save(self, file=None):
        """
        保存模型到指定位置。若不指定位置，则在固定目录下保存时间名的文件
        :param file: 保存文件的路径
        :return: None
        """
        if file is None:
            file = time.strftime(Config.save_dir + '_%y%m%d_%H_%M_%S.pth')
        torch.save(self.state_dict(), file)
        # print('model saved at {}'.format(file))
